// NOTE: n_wrong is accumulated per-block, then atomically added to global storage.
// atomicAdd() on double requires compute capability 6.0 or newer, so the
// accumulator is always a float, even for the double-precision kernel.
// Counts the elements handled by the calling thread block whose thresholded
// prediction disagrees with the thresholded label, then adds that count to
// *n_wrong with a single atomicAdd per block.
//
// Intended for a 1-D launch with one thread per element; threads whose global
// index is >= num contribute nothing. Must be called by every thread of the
// block, because it contains a block-wide barrier.
//
//   pred, labl : input arrays, read only at indices below num
//   num        : number of elements to compare
//   threshold  : a value is treated as "positive" when strictly greater
//   n_wrong    : global accumulator; this function only ever adds to it and
//                never resets it
//
// NOTE(review): a float represents integers exactly only up to 2^24, so the
// global total loses precision beyond that many mismatches.
template <typename T>
__device__ void binary_accuracy_forward(const T * const __restrict__ pred,
                                        const T * const __restrict__ labl,
                                        const int num, const T threshold,
                                          float * const __restrict__ n_wrong) {
  // Keep the index unsigned: the built-in index variables are unsigned, and
  // narrowing to int could turn an index above INT_MAX negative, letting it
  // slip past the bounds check.
  const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const bool in_range = (num > 0) && (idx < static_cast<unsigned int>(num));

  // A prediction is wrong when exactly one of pred/labl exceeds the threshold.
  // Out-of-range threads vote 0 so they still reach the barrier below.
  const int is_wrong =
      in_range ? ((pred[idx] > threshold) ^ (labl[idx] > threshold)) : 0;

  // Block-wide barrier that also returns how many threads passed a non-zero
  // predicate. This replaces the shared scratch array and the serial loop on
  // thread 0, and it follows the launched block size instead of a
  // compile-time constant, so no uninitialized or out-of-bounds shared
  // memory can be touched.
  const int block_n_wrong = __syncthreads_count(is_wrong);

  if (0 == threadIdx.x) {
    // One global atomic per block. The count is at most the block size, so
    // the int -> float conversion is exact.
    atomicAdd(n_wrong, static_cast<float>(block_n_wrong));
  }
}

extern "C" {
  // Single-precision kernel entry point: forwards every argument unchanged to
  // the binary_accuracy_forward template instantiated for float.
  __global__ void binary_accuracy_forward_float(
      const float * const __restrict__ pred,
      const float * const __restrict__ labl,
      const int num,
      float threshold,
      float * const __restrict__ n_wrong) {
    binary_accuracy_forward<float>(pred, labl, num, threshold, n_wrong);
  }
  // Double-precision kernel entry point: forwards every argument unchanged to
  // the binary_accuracy_forward template instantiated for double. The
  // mismatch accumulator n_wrong is still a float.
  __global__ void binary_accuracy_forward_double(
      const double * const __restrict__ pred,
      const double * const __restrict__ labl,
      const int num,
      double threshold,
      float * const __restrict__ n_wrong) {
    binary_accuracy_forward<double>(pred, labl, num, threshold, n_wrong);
  }
}

// vim: ft=cuda
